{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 训练唤醒词模型"
   ],
   "metadata": {
    "colab_type": "text",
    "id": "pO4-CY_TCZZS"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "This notebook demonstrates how to train a 20 kB [Simple Audio Recognition](https://www.tensorflow.org/tutorials/sequences/audio_recognition) model to recognize keywords in speech.\n",
    "\n",
    "The model created in this notebook is used in the [micro_speech](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/micro_speech) example for [TensorFlow Lite for MicroControllers](https://www.tensorflow.org/lite/microcontrollers/overview).\n",
    "\n"
   ],
   "metadata": {
    "colab_type": "text",
    "id": "BaFfr7DHRmGF"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "**Training is much faster using GPU acceleration.** Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and set **Hardware accelerator: GPU**. Training 15,000 iterations will take 1.5 - 2 hours on a GPU runtime.\n",
    "\n",
    "## 模型配置\n",
    "\n",
    "**MODIFY** the following constants for your specific use case."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "XaVtYN4nlCft"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# A comma-delimited list of the words you want to train for.\n",
    "# The options are: yes,no,up,down,left,right,on,off,stop,go\n",
    "# All the other words will be used to train an \"unknown\" label and silent\n",
    "# audio data with no spoken words will be used to train a \"silence\" label.\n",
    "WANTED_WORDS = \"on,off\"\n",
    "\n",
    "# The number of steps and learning rates can be specified as comma-separated\n",
    "# lists to define the rate at each stage. For example,\n",
    "# TRAINING_STEPS=12000,3000 and LEARNING_RATE=0.001,0.0001\n",
    "# will run 12,000 training loops in total, with a rate of 0.001 for the first\n",
    "# 8,000, and 0.0001 for the final 3,000.\n",
    "# TRAINING_STEPS = \"15000,3000\"\n",
    "TRAINING_STEPS = \"15000,3000\"\n",
    "LEARNING_RATE = \"0.001,0.0001\"\n",
    "\n",
    "# Calculate the total number of steps, which is used to identify the checkpoint\n",
    "# file name.\n",
    "TOTAL_STEPS = str(sum(map(lambda string: int(string), TRAINING_STEPS.split(\",\"))))\n",
    "\n",
    "# Print the configuration to confirm it\n",
    "print(\"Training these words: %s\" % WANTED_WORDS)\n",
    "print(\"Training steps in each stage: %s\" % TRAINING_STEPS)\n",
    "print(\"Learning rate in each stage: %s\" % LEARNING_RATE)\n",
    "print(\"Total number of training steps: %s\" % TOTAL_STEPS)"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ludfxbNIaegy"
   }
  },
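  {
   "cell_type": "markdown",
   "source": [
    "As a rough illustration of how the two comma-separated lists pair up, the sketch below expands `TRAINING_STEPS` and `LEARNING_RATE` into per-stage entries and looks up the rate that applies at a given step. The helpers `parse_schedule` and `rate_at_step` are hypothetical names used only for this example; `train.py` does its own parsing, which may differ in detail.\n",
    "\n",
    "```python\n",
    "def parse_schedule(steps_csv, rates_csv):\n",
    "    # Hypothetical helper: pair each stage's step count with its learning rate.\n",
    "    steps = [int(s) for s in steps_csv.split(',')]\n",
    "    rates = [float(r) for r in rates_csv.split(',')]\n",
    "    if len(steps) != len(rates):\n",
    "        raise ValueError('steps and rates must list the same number of stages')\n",
    "    return list(zip(steps, rates))\n",
    "\n",
    "\n",
    "def rate_at_step(schedule, step):\n",
    "    # Hypothetical helper: accumulate stage lengths until they cover `step` (1-based).\n",
    "    boundary = 0\n",
    "    for stage_steps, rate in schedule:\n",
    "        boundary += stage_steps\n",
    "        if step <= boundary:\n",
    "            return rate\n",
    "    raise ValueError('step is beyond the end of the schedule')\n",
    "\n",
    "\n",
    "schedule = parse_schedule('15000,3000', '0.001,0.0001')\n",
    "print(schedule)                        # [(15000, 0.001), (3000, 0.0001)]\n",
    "print(rate_at_step(schedule, 15000))   # 0.001\n",
    "print(rate_at_step(schedule, 15001))   # 0.0001\n",
    "```\n",
    "\n",
    "With the values configured above, the first 15,000 steps would use a rate of 0.001 and the remaining 3,000 would use 0.0001, for 18,000 steps in total."
   ],
   "metadata": {}
  },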
  {
   "cell_type": "markdown",
   "source": [
    "**DO NOT MODIFY** the following constants as they include filepaths used in this notebook and data that is shared during training and inference."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "gCgeOpvY9pAi"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Calculate the percentage of 'silence' and 'unknown' training samples required\n",
    "# to ensure that we have equal number of samples for each label.\n",
    "number_of_labels = WANTED_WORDS.count(',') + 1\n",
    "number_of_total_labels = number_of_labels + 2 # for 'silence' and 'unknown' label\n",
    "equal_percentage_of_training_samples = int(100.0/(number_of_total_labels))\n",
    "SILENT_PERCENTAGE = equal_percentage_of_training_samples\n",
    "UNKNOWN_PERCENTAGE = equal_percentage_of_training_samples\n",
    "VALIDATION_PERCENTAGE = 10 #10\n",
    "TESTING_PERCENTAGE = 10 # 10\n",
    "\n",
    "print('SILENT_PERCENTAGE : %d' % SILENT_PERCENTAGE)\n",
    "print('UNKNOWN_PERCENTAGE : %d' % UNKNOWN_PERCENTAGE)\n",
    "print('VALIDATION_PERCENTAGE : %d' % VALIDATION_PERCENTAGE)\n",
    "print('TESTING_PERCENTAGE : %d' % TESTING_PERCENTAGE)\n",
    "\n",
    "# Constants which are shared during training and inference\n",
    "PREPROCESS = 'micro'\n",
    "WINDOW_STRIDE = 20\n",
    "MODEL_ARCHITECTURE = 'tiny_conv' # Other options include: single_fc, conv,\n",
    "                      # low_latency_conv, low_latency_svdf, tiny_embedding_conv\n",
    "\n",
    "# Constants used during training only\n",
    "VERBOSITY = 'WARN'\n",
    "EVAL_STEP_INTERVAL = '100' # '1000'\n",
    "SAVE_STEP_INTERVAL = '100' # '1000'\n",
    "BATCH_SIZE = '64' # default 100\n",
    "\n",
    "# Constants for training directories and filepaths\n",
    "DATASET_DIR =  './dataset/'\n",
    "\n",
    "LOGS_DIR = 'logs/'\n",
    "TRAIN_DIR = 'train/' # for training checkpoints and other files.\n",
    "\n",
    "# Constants for inference directories and filepaths\n",
    "import os\n",
    "MODELS_DIR = 'models'\n",
    "if not os.path.exists(MODELS_DIR):\n",
    "  os.mkdir(MODELS_DIR)\n",
    "MODEL_TF = os.path.join(MODELS_DIR, 'model.pb')\n",
    "MODEL_TFLITE = os.path.join(MODELS_DIR, 'model.tflite')\n",
    "FLOAT_MODEL_TFLITE = os.path.join(MODELS_DIR, 'float_model.tflite')\n",
    "MODEL_TFLITE_MICRO = os.path.join(MODELS_DIR, 'model.cc')\n",
    "SAVED_MODEL = os.path.join(MODELS_DIR, 'saved_model')\n",
    "\n",
    "QUANT_INPUT_MIN = 0.0\n",
    "QUANT_INPUT_MAX = 26.0\n",
    "QUANT_INPUT_RANGE = QUANT_INPUT_MAX - QUANT_INPUT_MIN"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Nd1iM1o2ymvA"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 环境安装\n",
    "\n",
    "安装依赖项"
   ],
   "metadata": {
    "colab_type": "text",
    "id": "6rLYpvtg9P4o"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "#%tensorflow_version 1.x\n",
    "import tensorflow as tf"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ed_XpUrU5DvY"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "**DELETE** any old data from previous runs\n"
   ],
   "metadata": {
    "colab_type": "text",
    "id": "T9Ty5mR58E4i"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "#!rm -rf {DATASET_DIR} {LOGS_DIR} {TRAIN_DIR} {MODELS_DIR}\n",
    "!rm -rf {LOGS_DIR} {TRAIN_DIR} {MODELS_DIR}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "APGx0fEh7hFF"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Clone the TensorFlow Github Repository, which contains the relevant code required to run this tutorial."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "GfEUlfFBizio"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "#!git clone -q --depth 1 https://github.com/tensorflow/tensorflow"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yZArmzT85SLq"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Load TensorBoard to visualize the accuracy and loss as training proceeds.\n"
   ],
   "metadata": {
    "colab_type": "text",
    "id": "nS9swHLSi7Bi"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "#%load_ext tensorboard\n",
    "#%tensorboard --logdir {LOGS_DIR}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "q4qF1VxP3UE4"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型训练\n",
    "\n",
    "The following script downloads the dataset and begin training."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "x1J96Ron-O4R"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!python ./speech_commands/train.py \\\n",
    "--data_dir={DATASET_DIR} \\\n",
    "--wanted_words={WANTED_WORDS} \\\n",
    "--silence_percentage={SILENT_PERCENTAGE} \\\n",
    "--unknown_percentage={UNKNOWN_PERCENTAGE} \\\n",
    "--validation_percentage={VALIDATION_PERCENTAGE} \\\n",
    "--testing_percentage={TESTING_PERCENTAGE} \\\n",
    "--batch_size={BATCH_SIZE} \\\n",
    "--preprocess={PREPROCESS} \\\n",
    "--window_stride={WINDOW_STRIDE} \\\n",
    "--model_architecture={MODEL_ARCHITECTURE} \\\n",
    "--how_many_training_steps={TRAINING_STEPS} \\\n",
    "--learning_rate={LEARNING_RATE} \\\n",
    "--train_dir={TRAIN_DIR} \\\n",
    "--summaries_dir={LOGS_DIR} \\\n",
    "--verbosity={VERBOSITY} \\\n",
    "--eval_step_interval={EVAL_STEP_INTERVAL} \\\n",
    "--save_step_interval={SAVE_STEP_INTERVAL}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VJsEZx6lynbY"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 生成tensorflow模型\n",
    "\n",
    "Combine relevant training results (graph, weights, etc) into a single file for inference. This process is known as freezing a model and the resulting model is known as a frozen model/graph, as it cannot be further re-trained after this process."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "XQUJLrdS-ftl"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!rm -rf {SAVED_MODEL}\n",
    "!python ./speech_commands/freeze.py \\\n",
    "--wanted_words=$WANTED_WORDS \\\n",
    "--window_stride_ms=$WINDOW_STRIDE \\\n",
    "--preprocess=$PREPROCESS \\\n",
    "--model_architecture=$MODEL_ARCHITECTURE \\\n",
    "--start_checkpoint=$TRAIN_DIR$MODEL_ARCHITECTURE'.ckpt-'{TOTAL_STEPS} \\\n",
    "--save_format=saved_model \\\n",
    "--output_file={SAVED_MODEL}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xyc3_eLh9sAg"
   }
  },
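  {
   "cell_type": "markdown",
   "source": [
    "The `--start_checkpoint` argument above is assembled from three notebook variables. As a minimal sketch of that string concatenation (the function name `checkpoint_path` is hypothetical, and the sketch only builds the path; it does not check that the checkpoint file exists):\n",
    "\n",
    "```python\n",
    "def checkpoint_path(train_dir, architecture, training_steps_csv):\n",
    "    # Sum the per-stage step counts to get the final step number.\n",
    "    total_steps = sum(int(s) for s in training_steps_csv.split(','))\n",
    "    # Mirror the shell expression: TRAIN_DIR + MODEL_ARCHITECTURE + '.ckpt-' + TOTAL_STEPS.\n",
    "    return train_dir + architecture + '.ckpt-' + str(total_steps)\n",
    "\n",
    "\n",
    "print(checkpoint_path('train/', 'tiny_conv', '15000,3000'))  # train/tiny_conv.ckpt-18000\n",
    "```\n",
    "\n",
    "If you change `TRAINING_STEPS`, the suffix changes with it, so the freeze step looks for the checkpoint written at the new final step."
   ],
   "metadata": {}
  },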
  {
   "cell_type": "markdown",
   "source": [
    "## 生成Tensorflow Lite模型\n",
    "\n",
    "Convert the frozen graph into a TensorFlow Lite model, which is fully quantized for use with embedded devices.\n",
    "\n",
    "The following cell will also print the model size, which will be under 20 kilobytes."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "_DBGDxVI-nKG"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import sys\n",
    "# We add this path so we can import the speech processing modules.\n",
    "sys.path.append(\"./speech_commands/\")\n",
    "import input_data\n",
    "import models\n",
    "import numpy as np"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RIitkqvGWmre"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "SAMPLE_RATE = 16000\n",
    "CLIP_DURATION_MS = 1000\n",
    "WINDOW_SIZE_MS = 30.0\n",
    "FEATURE_BIN_COUNT = 40\n",
    "BACKGROUND_FREQUENCY = 0.8\n",
    "BACKGROUND_VOLUME_RANGE = 0.1\n",
    "TIME_SHIFT_MS = 100.0\n",
    "\n",
    "DATA_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'\n",
    "VALIDATION_PERCENTAGE = 10 #10\n",
    "TESTING_PERCENTAGE = 10 # 10"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kzqECqMxgBh4"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "model_settings = models.prepare_model_settings(\n",
    "    len(input_data.prepare_words_list(WANTED_WORDS.split(','))),\n",
    "    SAMPLE_RATE, CLIP_DURATION_MS, WINDOW_SIZE_MS,\n",
    "    WINDOW_STRIDE, FEATURE_BIN_COUNT, PREPROCESS)\n",
    "audio_processor = input_data.AudioProcessor(\n",
    "    DATA_URL, DATASET_DIR,\n",
    "    SILENT_PERCENTAGE, UNKNOWN_PERCENTAGE,\n",
    "    WANTED_WORDS.split(','), VALIDATION_PERCENTAGE,\n",
    "    TESTING_PERCENTAGE, model_settings, LOGS_DIR)"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rNQdAplJV1fz"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "with tf.Session() as sess:\n",
    "  float_converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)\n",
    "  float_tflite_model = float_converter.convert()\n",
    "  float_tflite_model_size = open(FLOAT_MODEL_TFLITE, \"wb\").write(float_tflite_model)\n",
    "  print(\"Float model is %d bytes\" % float_tflite_model_size)\n",
    "\n",
    "  converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)\n",
    "  converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "  converter.inference_input_type = tf.lite.constants.INT8\n",
    "  converter.inference_output_type = tf.lite.constants.INT8\n",
    "  def representative_dataset_gen():\n",
    "    set_size = audio_processor.set_size('testing') #get test set size\n",
    "    for i in range(set_size): # change 100 to set_size\n",
    "      data, _ = audio_processor.get_data(1, i*1, model_settings,\n",
    "                                         BACKGROUND_FREQUENCY, \n",
    "                                         BACKGROUND_VOLUME_RANGE,\n",
    "                                         TIME_SHIFT_MS,\n",
    "                                         'testing',\n",
    "                                         sess)\n",
    "      flattened_data = np.array(data.flatten(), dtype=np.float32).reshape(1, 1960)\n",
    "      yield [flattened_data]\n",
    "  converter.representative_dataset = representative_dataset_gen\n",
    "  tflite_model = converter.convert()\n",
    "  tflite_model_size = open(MODEL_TFLITE, \"wb\").write(tflite_model)\n",
    "  print(\"Quantized model is %d bytes\" % tflite_model_size)\n"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lBj_AyCh1cC0"
   }
  },
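  {
   "cell_type": "markdown",
   "source": [
    "The `reshape(1, 1960)` in the representative dataset generator depends on the feature dimensions. The sketch below shows one way to arrive at 1960 from the constants defined earlier, assuming the features are laid out as one row of `FEATURE_BIN_COUNT` values per analysis window and that only full windows are counted. The function `fingerprint_size` is a hypothetical helper written for this illustration, not a function from `models.py`.\n",
    "\n",
    "```python\n",
    "def fingerprint_size(sample_rate, clip_duration_ms, window_size_ms,\n",
    "                     window_stride_ms, feature_bin_count):\n",
    "    # Convert the durations in milliseconds to sample counts.\n",
    "    clip_samples = int(sample_rate * clip_duration_ms / 1000)\n",
    "    window_samples = int(sample_rate * window_size_ms / 1000)\n",
    "    stride_samples = int(sample_rate * window_stride_ms / 1000)\n",
    "    # Count the full windows that fit in the clip.\n",
    "    remaining = clip_samples - window_samples\n",
    "    frames = 0 if remaining < 0 else 1 + remaining // stride_samples\n",
    "    return frames, frames * feature_bin_count\n",
    "\n",
    "\n",
    "frames, size = fingerprint_size(16000, 1000, 30.0, 20, 40)\n",
    "print(frames, size)  # 49 1960\n",
    "```\n",
    "\n",
    "Under these assumptions, changing `WINDOW_STRIDE`, `WINDOW_SIZE_MS`, or `FEATURE_BIN_COUNT` would change this size, so the hard-coded 1960 would need to be updated to match."
   ],
   "metadata": {}
  },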
  {
   "cell_type": "markdown",
   "source": [
    "## 测试TensorFlow Lite model的精确度\n",
    "\n",
    "Verify that the model we've exported is still accurate, using the TF Lite Python API and our test set."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "EeLiDZTbLkzv"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Helper function to run inference\n",
    "def run_tflite_inference(tflite_model_path, model_type=\"Float\"):\n",
    "  # Load test data\n",
    "  np.random.seed(0) # set random seed for reproducible test results.\n",
    "  with tf.Session() as sess:\n",
    "    test_data, test_labels = audio_processor.get_data(\n",
    "        -1, 0, model_settings, BACKGROUND_FREQUENCY, BACKGROUND_VOLUME_RANGE,\n",
    "        TIME_SHIFT_MS, 'testing', sess)\n",
    "  test_data = np.expand_dims(test_data, axis=1).astype(np.float32)\n",
    "\n",
    "  # Initialize the interpreter\n",
    "  interpreter = tf.lite.Interpreter(tflite_model_path)\n",
    "  interpreter.allocate_tensors()\n",
    "\n",
    "  input_details = interpreter.get_input_details()[0]\n",
    "  output_details = interpreter.get_output_details()[0]\n",
    "\n",
    "  # For quantized models, manually quantize the input data from float to integer\n",
    "  if model_type == \"Quantized\":\n",
    "    input_scale, input_zero_point = input_details[\"quantization\"]\n",
    "    test_data = test_data / input_scale + input_zero_point\n",
    "    test_data = test_data.astype(input_details[\"dtype\"])\n",
    "\n",
    "  correct_predictions = 0\n",
    "  for i in range(len(test_data)):\n",
    "    interpreter.set_tensor(input_details[\"index\"], test_data[i])\n",
    "    interpreter.invoke()\n",
    "    output = interpreter.get_tensor(output_details[\"index\"])[0]\n",
    "    top_prediction = output.argmax()\n",
    "    correct_predictions += (top_prediction == test_labels[i])\n",
    "\n",
    "  print('%s model accuracy is %f%% (Number of test samples=%d)' % (\n",
    "      model_type, (correct_predictions * 100) / len(test_data), len(test_data)))"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wQsEteKRLryJ"
   }
  },
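  {
   "cell_type": "markdown",
   "source": [
    "The manual quantization step above applies the affine mapping `q = x / scale + zero_point`. The following is a standalone sketch of that mapping and its inverse. The `scale` and `zero_point` values are hypothetical, chosen to spread an input range of 0.0 to 26.0 across int8, whereas the real values come from `input_details[\"quantization\"]`. Unlike the cell above, this sketch rounds to the nearest integer and clips to the int8 range before casting.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def quantize(x, scale, zero_point, dtype=np.int8):\n",
    "    # Map real values to integers: q = round(x / scale + zero_point), clipped.\n",
    "    q = np.round(x / scale + zero_point)\n",
    "    info = np.iinfo(dtype)\n",
    "    return np.clip(q, info.min, info.max).astype(dtype)\n",
    "\n",
    "\n",
    "def dequantize(q, scale, zero_point):\n",
    "    # Approximate inverse: x ~= (q - zero_point) * scale.\n",
    "    return (q.astype(np.float32) - zero_point) * scale\n",
    "\n",
    "\n",
    "# Hypothetical parameters: 256 int8 levels covering the range 0.0 to 26.0.\n",
    "scale, zero_point = 26.0 / 255.0, -128\n",
    "x = np.array([0.0, 6.5, 26.0], dtype=np.float32)\n",
    "q = quantize(x, scale, zero_point)\n",
    "print(q.tolist())  # [-128, -64, 127]\n",
    "print(dequantize(q, scale, zero_point).tolist())\n",
    "```\n",
    "\n",
    "The round trip is lossy: each recovered value can differ from the original by up to about half of `scale`, which is one reason the quantized model's accuracy is checked separately from the float model's."
   ],
   "metadata": {}
  },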
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Compute float model accuracy\n",
    "run_tflite_inference(FLOAT_MODEL_TFLITE)\n",
    "\n",
    "# Compute quantized model accuracy\n",
    "run_tflite_inference(MODEL_TFLITE, model_type='Quantized')"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "l-pD52Na6jRa"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型转换\n",
    "Convert the TensorFlow Lite model into a C source file that can be loaded by TensorFlow Lite for Microcontrollers."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "dt6Zqbxu-wIi"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Install xxd if it is not available\n",
    "#!apt-get update && apt-get -qq install xxd\n",
    "# Convert to a C source file\n",
    "!xxd -i {MODEL_TFLITE} > {MODEL_TFLITE_MICRO}\n",
    "# Update variable names\n",
    "REPLACE_TEXT = MODEL_TFLITE.replace('/', '_').replace('.', '_')\n",
    "!sed -i 's/'{REPLACE_TEXT}'/g_model/g' {MODEL_TFLITE_MICRO}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XohZOTjR8ZyE"
   }
  },
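  {
   "cell_type": "markdown",
   "source": [
    "For readers without `xxd`, the conversion can be approximated in Python. The sketch below renders a byte string as a C array plus a length variable, similar in spirit to `xxd -i` followed by the `sed` rename to `g_model`. The function `to_c_array` and its formatting choices are hypothetical; the exact whitespace of real `xxd` output may differ.\n",
    "\n",
    "```python\n",
    "def to_c_array(data, var_name='g_model', bytes_per_line=12):\n",
    "    # Hypothetical helper: format bytes as a C unsigned char array and a length.\n",
    "    lines = []\n",
    "    for start in range(0, len(data), bytes_per_line):\n",
    "        chunk = data[start:start + bytes_per_line]\n",
    "        lines.append('  ' + ', '.join('0x%02x' % b for b in chunk))\n",
    "    body = ',\\n'.join(lines)\n",
    "    return ('unsigned char %s[] = {\\n%s\\n};\\nunsigned int %s_len = %d;\\n'\n",
    "            % (var_name, body, var_name, len(data)))\n",
    "\n",
    "\n",
    "# A few sample bytes stand in for the contents of the .tflite file.\n",
    "print(to_c_array(bytes([0x54, 0x46, 0x4c, 0x33])))\n",
    "```\n",
    "\n",
    "To convert a real model, the bytes could be read with `open(MODEL_TFLITE, 'rb').read()` and the returned string written to `MODEL_TFLITE_MICRO`."
   ],
   "metadata": {}
  },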
  {
   "cell_type": "markdown",
   "source": [
    "## 部署到微控制器\n",
    "\n",
    "Follow the instructions in the [micro_speech](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/micro_speech) README.md for [TensorFlow Lite for MicroControllers](https://www.tensorflow.org/lite/microcontrollers/overview) to deploy this model on a specific microcontroller.\n",
    "\n",
    "**Reference Model:** If you have not modified this notebook, you can follow the instructions as is, to deploy the model. Refer to the [`micro_speech/train/models`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/train/models) directory to access the models generated in this notebook.\n",
    "\n",
    "**New Model:** If you have generated a new model to identify different words: (i) Update `kCategoryCount` and `kCategoryLabels` in [`micro_speech/micro_features/micro_model_settings.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h) and (ii) Update the values assigned to the variables defined in [`micro_speech/micro_features/model.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc) with values displayed after running the following cell."
   ],
   "metadata": {
    "colab_type": "text",
    "id": "2pQnN0i_-0L2"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Print the C source file\n",
    "!cat {MODEL_TFLITE_MICRO}"
   ],
   "outputs": [],
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eoYyh0VU8pca"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 生成测试音频流\n",
    "用于测试TFLite模型，可选。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!python ./speech_commands/generate_streaming_test_wav.py \\\n",
    "    --wanted_words='on,off' \\\n",
    "    --data_dir=./my_only_dataset --background_dir=./my_only_dataset/_background_noise_ \\\n",
    "    --background_volume=0.1 --test_duration_seconds=600 \\\n",
    "    --output_audio_file=./tmp/streaming_test.wav \\\n",
    "    --output_labels_file=./tmp/streaming_test_labels.txt"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 测试音频流准确率"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!python ./speech_commands/test_streaming_accuracy.py \\\n",
    "    --wav=./tmp/streaming_test.wav \\\n",
    "    --ground-truth=./tmp/streaming_test_labels.txt --verbose \\\n",
    "    --model=./models/saved_model/ \\\n",
    "    --labels=./train/tiny_conv_labels.txt \\\n",
    "    --clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \\\n",
    "    --suppression_ms=500 --time_tolerance_ms=1500"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "train_micro_speech_model.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.6 64-bit ('ai': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "interpreter": {
   "hash": "abb811d8b1be5969553a0033f5a252013244c1763569cb77f5209b62dcc65a34"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}